import pandas as pd
from data_resource.data_bases import engine


def stockPool_raw(pool_type='csi800_1000', include_industry=False, con=None):
    """Return the initial (raw) stock pool.

    Parameters
    ----------
    pool_type : str
        Name of the pool to load. Only ``'csi800_1000'`` is supported: the
        union of the constituents of ``000852.SH`` and ``000906.SH`` on the
        latest ``trade_date`` stored for those two indices in
        ``quant_research.index_constituent``.
    include_industry : bool
        If True, return a DataFrame with columns ``code`` and ``industry``.
        ``industry`` is ``l1_name`` from
        ``quant_research.sw_industry_constituent`` and is missing when a code
        has no match there. If False, return only the list of codes.
    con : optional
        Connection/engine accepted by ``pandas.read_sql``. Defaults to the
        module-level ``engine``; pass another one to query a different
        database (e.g. in tests).

    Returns
    -------
    pandas.DataFrame or list of str
        One row / entry per constituent code, ordered by code.

    Raises
    ------
    ValueError
        If ``pool_type`` is not supported.
    """
    # Index codes that make up each supported pool.
    pool_index_codes = {
        'csi800_1000': ('000852.SH', '000906.SH'),
    }
    if pool_type not in pool_index_codes:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would make this function silently return None.
        raise ValueError(
            f"unsupported pool_type {pool_type!r}; "
            f"expected one of {sorted(pool_index_codes)}"
        )

    # The codes come from the fixed mapping above, never from the caller, so
    # inlining them is safe and avoids driver-specific parameter styles.
    codes_sql = ', '.join(f"'{code}'" for code in pool_index_codes[pool_type])
    # The snapshot-date subquery is restricted to the requested indices: a
    # table-wide max(trade_date) yields an empty pool whenever some other
    # index in the table has a more recent snapshot. Columns are qualified
    # to avoid ambiguity with the joined industry table.
    sql = f"""
    select a.con_code as code, b.l1_name as industry
    from quant_research.index_constituent as a
    left join quant_research.sw_industry_constituent as b
        on a.con_code = b.ts_code
    where a.index_code in ({codes_sql})
      and a.trade_date = (
        select max(c.trade_date)
        from quant_research.index_constituent as c
        where c.index_code in ({codes_sql})
      )
    order by a.con_code
    """
    pool = pd.read_sql(sql, engine if con is None else con)
    # A code matched by several industry rows would otherwise appear more
    # than once and be double-counted downstream; keep one row per code.
    pool = pool.drop_duplicates(subset='code').reset_index(drop=True)
    return pool if include_industry else pool['code'].tolist()
